from dataclasses import dataclass
import os
from typing import Union
import json

from .config import ConfigScriptRewriter

from .prompt import GEN_ACT_SYSTEM, GEN_ACT_USER, GEN_DIAG_SYSTEM, GEN_DIAG_USER, GEN_SCENE_SYSTEM, GEN_SCENE_USER, SUMMARIZE_SYSTEM, SUMMARIZE_USER, SYSTEM_FORMAT_PROMPT, USER_FORMAT_PROMPT, extract_action_changes, extract_dialogue_changes, extract_episode_summary, extract_three_act

from .models import CallModel

from .utils import connect_logging, get_unique_id, save_json, read_json, save_file, count_words


class ScriptRewriter:

    def __init__(self, config: ConfigScriptRewriter):
        """Set up the rewriter from ``config``.

        Resolves the run's unique id and output directory, optionally attaches
        a log file, builds the model client and, when ``config.load_data`` is
        set, restores previously saved results.

        Note that ``config`` is updated in place: ``unique_id`` is filled in
        when it is ``None`` and ``output_save_dir`` is extended with a
        ``<unique_id>_<output_basename>/`` sub-directory.

        Args:
            config: Rewriter configuration. The attributes read here are
                ``unique_id``, ``print_log``, ``output_save_dir``,
                ``output_basename``, ``model_config`` and ``load_data``.
        """
        self.debug_list = []  # every message emitted through add_debug()
        self.log_file = None

        # --- Validate / normalise the config ---------------------------------
        if config.unique_id is None:
            # No id supplied: generate one (the original author notes that it
            # is derived from the current time by default).
            config.unique_id = get_unique_id()
        if config.print_log:
            # Attach the log file for this run.
            self.log_file = connect_logging(
                save_dir='./logs', name="ScriptRewriter")

        # Effective output path: <output_save_dir>/<unique_id>_<basename>/
        if not config.output_save_dir.endswith('/'):
            config.output_save_dir += '/'
        config.output_save_dir += f'{config.unique_id}_{config.output_basename}/'

        # --- Keep the config and build the model client ----------------------
        self.config = config
        self.model_config = config.model_config
        self.model = self.set_model()

        # Information needed to resume from a checkpoint.
        self.output_save_dir = config.output_save_dir
        self.basename = config.output_basename
        self.unique_id = config.unique_id

        # --- Other parameters -------------------------------------------------
        # Set before any data is loaded so the instance is fully initialised
        # by the time other methods run.
        self.max_retry = 5  # extra attempts allowed per model request

        # --- Generation state -------------------------------------------------
        self.rewriter_elements = []  # one dict per rewritten episode
        self.current_gen_number = 1  # 1-based number of the next episode
        if config.load_data:
            # Resume from the data saved by a previous run.
            self.load_data()

    def load_data(self, load_to_number: int = -1):  # 从上次保存的文件中读取数据
        if load_to_number == 0:
            self.info(f"load_data(): load_to_number=0, do nothing.")
            return

        path = self.output_save_dir + 'json_data/'
        try:
            rewriter_elements = read_json(path, 'rewriter_elements.json')
            if rewriter_elements:
                self.rewriter_elements = rewriter_elements
                self.current_gen_number = len(rewriter_elements) + 1
                self.info(
                    f"Successfully loaded 'rewriter_elements {1}-{self.current_gen_number-1}' from '{path}'.")
            else:
                self.info(f"Successfully loaded 'rewriter_elements 0")
            # if load_to_number >= 1:
            #     while len(self.chapter_elements) > load_to_number:
            #         self.backspace_data()
        except FileNotFoundError:
            self.info(f"Loaded no data from '{path}'.")

    def backspace_data(self):
        if self.current_gen_number > 1:
            self.info(
                f"Successfully backspace 'rewriter_elements {1}-{self.current_gen_number-1}' to 'rewriter_elements {1}-{self.current_gen_number-2}'")
            self.chapter_elements = self.chapter_elements[:-1]
            self.current_gen_number -= 1
        else:
            self.info("No rewriter_elements data for backspace.")

    def clear_data(self):  # 清空生成过程中保存的数据
        self.rewriter_elements = []
        self.current_gen_number = 1

    def set_model(self, **kwargs):
        """Build a ``CallModel`` from ``self.model_config``.

        An optional ``model_name`` keyword overrides the configured model
        name. The override is written onto ``self.model_config`` itself, so it
        persists for later calls.

        Returns:
            The newly created ``CallModel``.
        """
        model_config = self.model_config
        configured_name = self.model_config.model_name
        model_config.model_name = kwargs.get("model_name", configured_name)
        return CallModel(model_config)

    def add_debug(self, text):
        if text != None and text != "":
            print(text)
            self.debug_list.append(text)
            return self.debug_list

    def info(self, text):
        if text != None and text != "":
            self.add_debug(f"    [Info] {text}")
            return self.debug_list

    def warning(self, text):
        if text != None and text != "":
            self.add_debug(f"    [Warning] {text}")
            return self.debug_list

    def scenes_rewriter(self, raw_script, main_conflict, scene_settings, **generate_kwargs):
        """Rewrite a script around its main conflict and scene settings.

        Fills ``GEN_SCENE_USER`` with the three inputs, sends it to the model
        together with ``GEN_SCENE_SYSTEM`` and hands the reply to
        ``extract_three_act``. The request is repeated, up to
        ``self.max_retry`` extra times, until the extracted result contains a
        non-empty ``"rewritten_script"``.

        Args:
            raw_script: Script text to rewrite.
            main_conflict: Main conflict inserted into the prompt.
            scene_settings: Scene settings inserted into the prompt.
            **generate_kwargs: Forwarded to ``self.model.get_response``.

        Returns:
            The rewritten script, or an empty/falsy value when every attempt
            failed.
        """
        system_prompt = GEN_SCENE_SYSTEM
        user_input = GEN_SCENE_USER.format(
            raw_script=raw_script,
            main_conflict=main_conflict,
            scene_settings=scene_settings,
        )
        # attempt_no counts from 1; the final pass is number max_retry + 1.
        for attempt_no in range(1, self.max_retry + 2):
            reply = self.model.get_response(
                system_prompt, user_input,
                task=f'scenes_rewriter {self.current_gen_number}',
                save_dir=self.output_save_dir, **generate_kwargs)
            extracted_info = extract_three_act(reply)
            if extracted_info.get("rewritten_script"):
                break
            if attempt_no <= self.max_retry:
                # Only announce a retry when one will actually happen.
                self.warning(
                    f"scenes_rewriter(): Failed extract_json! Retry {attempt_no}/{self.max_retry}...")

        return extracted_info.get("rewritten_script", "")

    def action_rewriter(self, raw_script, character_psychology, character_relations, action_limit: str, **generate_kwargs):
        """Polish the action and camera descriptions of a script.

        Fills ``GEN_ACT_USER`` with the inputs, sends it to the model together
        with ``GEN_ACT_SYSTEM`` and hands the reply to
        ``extract_action_changes``. The request is repeated, up to
        ``self.max_retry`` extra times, until the extracted result contains a
        non-empty ``"rewritten_script"``.

        Per the original author's notes, the prompt asks for descriptions
        that follow the characters' psychological and relationship changes.
        ``action_limit`` states how much description is wanted (levels such
        as very few / few / moderate / many); ``None`` leaves the amount to
        the model.

        Args:
            raw_script: Script text to rewrite.
            character_psychology: Character psychology changes for the prompt.
            character_relations: Character relationship changes for the prompt.
            action_limit: Desired amount of action/camera description.
            **generate_kwargs: Forwarded to ``self.model.get_response``.

        Returns:
            The rewritten script, or an empty/falsy value when every attempt
            failed.
        """
        system_prompt = GEN_ACT_SYSTEM
        user_input = GEN_ACT_USER.format(
            raw_script=raw_script,
            character_psychology=character_psychology,
            character_relations=character_relations,
            action_limit=action_limit,
        )
        # attempt_no counts from 1; the final pass is number max_retry + 1.
        for attempt_no in range(1, self.max_retry + 2):
            reply = self.model.get_response(
                system_prompt, user_input,
                task=f'action_rewriter {self.current_gen_number}',
                save_dir=self.output_save_dir, **generate_kwargs)
            extracted_info = extract_action_changes(reply)
            if extracted_info.get("rewritten_script"):
                break
            if attempt_no <= self.max_retry:
                # Only announce a retry when one will actually happen.
                self.warning(
                    f"action_rewriter(): Failed extract_json! Retry {attempt_no}/{self.max_retry}...")

        return extracted_info.get("rewritten_script", "")

    def dialogue_rewriter(self, raw_script, character_psychology, character_relations,
                          dialogues_limit: tuple, script_info: dict = None,
                          **generate_kwargs):
        """Merge or expand the dialogue of a script.

        Fills ``GEN_DIAG_USER`` with the inputs, sends it to the model
        together with ``GEN_DIAG_SYSTEM`` and hands the reply to
        ``extract_dialogue_changes``. The request is repeated, up to
        ``self.max_retry`` extra times, until the extracted result contains a
        non-empty ``"rewritten_script"``.

        Per the original author's notes, the prompt asks for dialogue that
        follows the characters' psychological and relationship changes while
        keeping the number of dialogues within ``dialogues_limit``; ``None``
        leaves the amount to the model.

        Args:
            raw_script: Script text to rewrite.
            character_psychology: Character psychology changes for the prompt.
            character_relations: Character relationship changes for the prompt.
            dialogues_limit: Allowed range for the number of dialogues.
            script_info: Accepted but not used by this method.
            **generate_kwargs: Forwarded to ``self.model.get_response``.

        Returns:
            The rewritten script, or an empty/falsy value when every attempt
            failed.
        """
        system_prompt = GEN_DIAG_SYSTEM
        user_input = GEN_DIAG_USER.format(
            raw_script=raw_script,
            character_psychology=character_psychology,
            character_relations=character_relations,
            dialogues_limit=dialogues_limit,
        )
        # attempt_no counts from 1; the final pass is number max_retry + 1.
        for attempt_no in range(1, self.max_retry + 2):
            reply = self.model.get_response(
                system_prompt, user_input,
                task=f'dialogue_rewriter {self.current_gen_number}',
                save_dir=self.output_save_dir, **generate_kwargs)
            extracted_info = extract_dialogue_changes(reply)
            if extracted_info.get("rewritten_script"):
                break
            if attempt_no <= self.max_retry:
                # Only announce a retry when one will actually happen.
                self.warning(
                    f"dialogue_rewriter(): Failed extract_json! Retry {attempt_no}/{self.max_retry}...")

        return extracted_info.get("rewritten_script", "")

    def format_rewriter(self, raw_script, **generate_kwargs):
        """Ask the model to normalise the formatting of a script.

        Fills ``USER_FORMAT_PROMPT`` with the script, sends it to the model
        together with ``SYSTEM_FORMAT_PROMPT`` and removes any
        "Revised Script:" style header from the reply. The request is
        repeated, up to ``self.max_retry`` extra times, until the cleaned
        reply is non-empty.

        Args:
            raw_script: Script text to reformat.
            **generate_kwargs: Forwarded to ``self.model.get_response``.

        Returns:
            The cleaned reply of the last attempt (empty when every attempt
            produced nothing).
        """
        # Header variants stripped from the reply, longest forms first.
        header_markers = (
            '### Revised Script:',
            '### REVISED SCRIPT:',
            'Revised Script:',
            'REVISED SCRIPT:',
        )
        system_prompt = SYSTEM_FORMAT_PROMPT
        user_input = USER_FORMAT_PROMPT.format(current_episode=raw_script)
        # attempt_no counts from 1; the final pass is number max_retry + 1.
        for attempt_no in range(1, self.max_retry + 2):
            response = self.model.get_response(
                system_prompt, user_input,
                task=f'format_rewriter {self.current_gen_number}',
                save_dir=self.output_save_dir, **generate_kwargs)
            for marker in header_markers:
                response = response.replace(marker, '').strip()
            if response:
                break
            if attempt_no <= self.max_retry:
                # Only announce a retry when one will actually happen.
                self.warning(
                    f"format_rewriter(): Failed extract_json! Retry {attempt_no}/{self.max_retry}...")

        return response

    def summarize_script(self, rewritten_script, episode_summaries: str, **generate_kwargs):
        """Summarise the rewritten episode and the story so far.

        Fills ``SUMMARIZE_USER`` with the episode and the earlier episode
        summaries, sends it to the model together with ``SUMMARIZE_SYSTEM``
        and hands the reply to ``extract_episode_summary``. The request is
        repeated, up to ``self.max_retry`` extra times, until both summaries
        are non-empty. Per the original author's notes, the up-to-now summary
        serves as the recap used when producing later episodes.

        Args:
            rewritten_script: Final script of the current episode.
            episode_summaries: Text holding the summaries of earlier episodes.
            **generate_kwargs: Forwarded to ``self.model.get_response``, as in
                the other rewriter methods. (They used to be passed to
                ``str.format``, which silently ignores unknown keywords, so
                they never reached the model.)

        Returns:
            ``(episode_summary, uptonow_summary)`` from the last attempt;
            either may be empty when every attempt failed.
        """
        system_prompt = SUMMARIZE_SYSTEM
        user_input = SUMMARIZE_USER.format(rewritten_script=rewritten_script,
                                           episode_summaries=episode_summaries)
        for attempt in range(self.max_retry + 1):
            response = self.model.get_response(system_prompt, user_input,
                                               task=f'summarize_script {self.current_gen_number}',
                                               save_dir=self.output_save_dir, **generate_kwargs)
            episode_summary, uptonow_summary = extract_episode_summary(
                response)
            if episode_summary and uptonow_summary:
                break
            if attempt < self.max_retry:
                # Only announce a retry when one will actually happen.
                self.warning(
                    f"summarize_script(): Failed extract_json! Retry {attempt+1}/{self.max_retry}...")

        return episode_summary, uptonow_summary

    def rewrite_pipeline(self, global_elements: dict, chapter_elements: list, dialogues_limit: tuple = None, action_limit: str = "moderate", **generate_kwargs):

        if generate_kwargs.get('temperature') == None:  # 温度默认使用0
            generate_kwargs['temperature'] = 0

        # 断点续传数据崩坏时可能会遇到的极端情况（正常应该先提取章节信息再改写剧本，所以左边长度应该等于右边+1）
        if len(chapter_elements) > len(self.rewriter_elements) + 1:  # 如果章节提取器比改写器多
            self.warning(
                'chapter_elements is longer than rewriter_elements + 1, please use ChapterExtractor.backspace_data() to fix it')
            return {'rewriter_elements': self.rewriter_elements, 'message': 'Fail'}

        # 获取全局提取器要素

        # 获取章节提取器要素
        try:
            raw_script = chapter_elements[self.current_gen_number -
                                          1].get('raw_script', {})
            chapter_info = chapter_elements[self.current_gen_number-1].get(
                'chapter_info', {})

            main_conflict = chapter_info.get("main_conflict", "")
            scene_settings = chapter_info.get("scene_settings", "")
            character_psychology = chapter_info.get("character_psychology", "")
            character_relations = chapter_info.get("character_relations", "")
        except Exception as e:
            self.warning(
                f'Failed to get the chapter_elements. Error message: {str(e)}')
            return {'rewriter_elements': self.rewriter_elements, 'message': 'Fail'}

        # 步骤 1: 基于主要冲突和场景设置改写剧本
        scenes_rewritten_script = ""
        try:
            scenes_rewritten_script = self.scenes_rewriter(raw_script,
                                                           main_conflict, scene_settings, **generate_kwargs)
            if not scenes_rewritten_script:
                raise Exception("Fail")
            self.info(
                f"Successfully rewrite scenes. Total words = {count_words(scenes_rewritten_script)}")
        except Exception as e:
            self.warning(f'Failed to rewrite scenes. Error message: {str(e)}')

        scenes_rewritten_script = raw_script
        # 步骤 2: 改写动作和镜头描述
        action_rewritten_script = ""
        try:
            action_rewritten_script = self.action_rewriter(scenes_rewritten_script,
                                                           character_psychology, character_relations, action_limit, **generate_kwargs)
            if not action_rewritten_script:
                raise Exception("Fail")
            self.info(
                f"Successfully rewrite actions. Total words = {count_words(action_rewritten_script)}")
        except Exception as e:
            self.warning(f'Failed to rewrite actions. Error message: {str(e)}')

        # 步骤 3: 基于角色心理和关系改写对话
        dialogues_rewritten_script = ""
        try:
            dialogues_rewritten_script = self.dialogue_rewriter(action_rewritten_script,
                                                                character_psychology, character_relations, dialogues_limit, **generate_kwargs)
            if not dialogues_rewritten_script:
                raise Exception("Fail")
            self.info(
                f"Successfully rewrite dialogues. Total words = {count_words(dialogues_rewritten_script)}")
        except Exception as e:
            self.warning(
                f'Failed to rewrite dialogues. Error message: {str(e)}')

        # 步骤 4: 改写剧本格式内容规范
        format_rewritten_script = ""
        try:
            format_rewritten_script = self.format_rewriter(dialogues_rewritten_script,
                                                           **generate_kwargs)
            if not format_rewritten_script:
                raise Exception("Fail")
            self.info(
                f"Successfully rewrite format. Total words = {count_words(format_rewritten_script)}")
        except Exception as e:
            self.warning(f'Failed to rewrite format. Error message: {str(e)}')

        # 得到最终剧本
        final_script = format_rewritten_script

        # 步骤 5: 总结剧本
        current_summary, uptonow_summary = "", ""
        try:
            current_summary, uptonow_summary = self.summarize_script(
                final_script, self.get_episode_summary_uptonow())
            if not current_summary or not uptonow_summary:
                raise Exception("Fail")
            self.info(
                f"Successfully summarize script. Total words = {count_words(current_summary), count_words(uptonow_summary)}")
        except Exception as e:
            self.warning(
                f'Failed to summarize script. Error message: {str(e)}')

        # 返回最终改编的剧本及其总结
        if final_script and current_summary and uptonow_summary:
            elements = {
                "final_script": final_script,  # 本集的最终剧本
                "uptonow_summary": uptonow_summary,  # 迄今为止所有集摘要
                "episode_summary": current_summary,  # 当前集摘要
                "rewritten_scripts": {  # 剧本改写的中间结果
                    "scenes_rewritten_script": scenes_rewritten_script,  # 1.改写了场景的剧本
                    "dialogues_rewritten_script": dialogues_rewritten_script,  # 2.改写了对话的剧本
                    "action_rewritten_script": action_rewritten_script,  # 3.改写了动作的剧本
                    "format_rewritten_script": format_rewritten_script,  # 4.格式化后的剧本
                },
            }
            self.rewriter_elements.append(elements)
            self.save_data_all(self.rewriter_elements)

            self.info(("Successfully rewrite_pipeline. Data has been saved."))

            self.current_gen_number += 1
            return {'rewriter_elements': self.rewriter_elements, 'message': 'Success'}
        else:
            return {'rewriter_elements': self.rewriter_elements, 'message': 'Fail'}

    def get_episode_summary_uptonow(self):
        episode_summary_uptonow = ''.join(
            [f"Plot summary of episode {i+1} : {elements['episode_summary']}\n\n" for i,
                elements in enumerate(self.rewriter_elements)]
        )
        return episode_summary_uptonow

    def get_episode_uptonow(self):
        episode_uptonow = ''.join(
            [f"- [Episode {i+1}] -\n\n{elements['final_script']}\n\n" for i,
                elements in enumerate(self.rewriter_elements)]
        )
        return episode_uptonow

    # 【数据储存和展示】：
    def save_data_all(self, rewriter_elements):
        """Persist the rewrite results as JSON plus two text reports.

        Writes, below ``self.output_save_dir``:
            * ``json_data/``: ``rewriter_elements`` via ``save_json``
            * ``docx_data/``: ``final_scripts`` - the latest up-to-now summary
              followed by the final script of every episode
            * ``docx_data/``: ``rewritten_scripts`` - per episode, both
              summaries and the script produced by each rewrite stage

        Args:
            rewriter_elements: List of per-episode dicts as built by
                ``rewrite_pipeline``. Missing keys are rendered as empty text
                and an empty list produces reports without episodes.
        """
        json_dir = self.output_save_dir + 'json_data/'
        save_json(rewriter_elements, json_dir, 'rewriter_elements')
        self.info(
            f"Data 'rewriter_elements.json'(Total length={len(rewriter_elements)}) has been saved to '{json_dir}'.")

        # The synopsis comes from the data being saved (the argument), not
        # from whatever the instance currently holds.
        latest_summary = ""
        if rewriter_elements:
            latest_summary = rewriter_elements[-1].get('uptonow_summary', "")

        script_parts = [
            f"Final Scripts of '{self.basename}':\n\n- [SYNOPISIS] -\n{latest_summary}\n\n"]
        report_parts = [f"Rewritten Scripts Info of '{self.basename}':\n"]

        for episode_no, elements in enumerate(rewriter_elements, start=1):
            final_script = elements.get('final_script', "")
            uptonow_summary = elements.get('uptonow_summary', "")
            episode_summary = elements.get('episode_summary', "")

            rewritten_scripts = elements.get('rewritten_scripts', {})
            scenes_rewritten_script = rewritten_scripts.get(
                'scenes_rewritten_script', "")
            action_rewritten_script = rewritten_scripts.get(
                'action_rewritten_script', "")
            dialogues_rewritten_script = rewritten_scripts.get(
                'dialogues_rewritten_script', "")
            format_rewritten_script = rewritten_scripts.get(
                'format_rewritten_script', "")

            script_parts.append(
                f" - [Episode {episode_no}] -\n{final_script}\n\n")

            # The template lines start at column 0 on purpose: they are the
            # literal content of the report.
            report_parts.append(f"""
[Rewritten script {episode_no}]\n
    [Up To Now Summary] : \n        {uptonow_summary}\n
    [Up To Now Summary Feedback] : ...\n
    [Current summary] : \n        {episode_summary}\n
    [Current summary Feedback] : ...\n
    [The 1st Scenes Rewritten episode (Rewritten Scenes)] :\n        {scenes_rewritten_script}\n
    [The 1st Scenes Rewritten episode Feedback] : ...\n
    [The 2st Actions Rewritten episode (Rewritten Actions)] : \n        {action_rewritten_script}\n
    [The 2st Actions Rewritten episode Feedback] : ...
    [The 3st Dialogues Rewritten episode (Rewritten Dialogues)] :\n        {dialogues_rewritten_script}\n
    [The 3st Dialogues Rewritten episode Feedback] : ...
    [The 4st(Final) Format Rewritten episode (Rewritten Actions)] : \n        {format_rewritten_script}\n
    [The 4st(Final) Format Rewritten episode Feedback] : ...
""")

        docx_dir = self.output_save_dir + 'docx_data/'
        save_file(''.join(script_parts), docx_dir, 'final_scripts', 'docx')
        save_file(''.join(report_parts), docx_dir, 'rewritten_scripts', 'docx')
